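1. Implementation
When preprocessing data with PyTorch, torchvision.transforms.Normalize(mean, std) is commonly used to standardize the data. The parameters mean and std are sequences holding the mean and the standard deviation of each channel of the image dataset.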
First, mean and std are defined mathematically as follows.
Suppose we have a dataset $X_i,\,\,i\in\{1,2,\cdots,n\}$. The mean of this dataset is:
$$mean=\frac{\displaystyle\sum_{i=1}^nX_i}{n}\tag{1}$$
The mean of the data is usually denoted by $\overline X$.
The standard deviation of this dataset is:
$$\begin{aligned}
std&=\sqrt{\frac{\displaystyle\sum_{i=1}^n\left(X_i-\overline X\right)^2}{n}}\\[2ex]
&=\sqrt{\frac{\displaystyle\sum_{i=1}^n(X_i^2-2X_i\overline X+\overline X^2)}{n}}\\[2ex]
&=\sqrt{\frac{\left(\displaystyle\sum_{i=1}^nX_i^2\right)-n\overline X^2}{n}}\\[2ex]
&=\sqrt{\frac{\displaystyle\sum_{i=1}^nX_i^2}{n}-\overline X^2}
\end{aligned}\tag{2}$$
The third line uses $\sum_{i=1}^nX_i=n\overline X$, so that $-2\overline X\sum_{i=1}^nX_i+n\overline X^2=-n\overline X^2$.
The following code gives a function that computes the mean and standard deviation of each channel of an image dataset:
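import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

batch_size = 64

# Training set (CIFAR-10 is used as an example)
train_dataset = datasets.CIFAR10(root='G:/datasets/cifar10', train=True, download=False, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)


def get_mean_std_value(loader):
    '''
    Compute the per-channel mean and standard deviation of a dataset.
    :param loader: DataLoader yielding (data, label) batches, data shaped [batch_size, channels, height, width]
    :return: (mean, std), each a tensor of shape [channels]
    '''
    data_sum, data_squared_sum, num_samples = 0, 0, 0

    for data, _ in loader:
        # data: [batch_size, channels, height, width]
        n = data.size(0)
        # Mean over dims 0, 2, 3; dim=1 is the channel dimension and is kept, so the result is [channels].
        # Each batch mean is weighted by its batch size so that a smaller final batch does not bias the result.
        data_sum += torch.mean(data, dim=[0, 2, 3]) * n
        # Mean of the squared values over the same dims, weighted the same way; result is [channels].
        data_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3]) * n
        # Count the number of samples seen so far
        num_samples += n

    # Mean, Eq. (1)
    mean = data_sum / num_samples
    # Standard deviation, last form of Eq. (2)
    std = (data_squared_sum / num_samples - mean ** 2) ** 0.5
    return mean, std


mean, std = get_mean_std_value(train_loader)
print('mean = {},std = {}'.format(mean, std))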
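The function above relies on the last form of Eq. (2), which needs only the running averages of the values and of their squares. As a quick sanity check, the sketch below compares that shortcut with the direct definition on synthetic data; the random values stand in for pixel intensities and are an assumption of this example.

```python
import math
import random
import statistics

random.seed(0)
# Synthetic data standing in for pixel values in [0, 1].
xs = [random.random() for _ in range(1000)]
n = len(xs)

# Eq. (1): arithmetic mean.
mean = sum(xs) / n
# Average of the squared values, the first term under the root in Eq. (2).
mean_of_squares = sum(x * x for x in xs) / n

# Last form of Eq. (2): sqrt(mean of squares - squared mean).
std_shortcut = math.sqrt(mean_of_squares - mean ** 2)

# First form of Eq. (2): population standard deviation (divides by n).
std_direct = statistics.pstdev(xs)

# Reports whether the two forms agree up to rounding error.
print(math.isclose(std_shortcut, std_direct, rel_tol=1e-9))
```

Both forms divide by n, that is, they give the population standard deviation. torch.std applies Bessel's correction (division by n - 1) by default, so its result differs slightly on small samples.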
The mean and standard deviation of the CIFAR10 dataset are:
mean = tensor([0.4914, 0.4821, 0.4465]),std = tensor([0.2470, 0.2435, 0.2616])
The mean and standard deviation of the MNIST dataset are:
mean = tensor([0.1307]),std = tensor([0.3081])